library(here)
Warning message:
In do_once((if (is_R_CMD_check()) stop else warning)("The function xfun::isFALSE() will be deprecated in the future. Please ",  :
  The function xfun::isFALSE() will be deprecated in the future. Please consider using base::isFALSE(x) or identical(x, FALSE) instead.
library(cowplot)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))
# Discover the model identifiers from the processed-diagnosis file names.
# Splitting e.g. "diagnoses_<model>_icd.csv.gz" on "diagnoses_", "_icd" or
# ".csv" leaves the model identifier as the second piece.
# The dot in ".csv" is escaped so it only matches a literal dot, and vapply()
# guarantees a character vector even when no files are found.
all_models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>% 
  str_split("diagnoses_|_icd|\\.csv") %>% 
  vapply(function(x) x[2], character(1)) %>% 
  unique()
all_models
[1] "claude-3-haiku-20240307_t1-0"       "claude-3-opus-20240229_t1-0"       
[3] "gemini-1.0-pro-002_t1-0"            "gemini-1.5-flash-preview-0514_t1-0"
[5] "gemini-1.5-pro-001_t1-0"            "gpt-3.5-turbo-1106"                
[7] "gpt-4-turbo-preview"               

Import data

# Free-text (original) responses, one data frame per model
df_gpt3.5                 <- read_model("gpt-3.5-turbo-1106",          icd = FALSE)
df_gpt4.0                 <- read_model("gpt-4-turbo-preview",         icd = FALSE)
df_claude3_haiku_t1.0     <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0      <- read_model("claude-3-opus-20240229_t1-0",  icd = FALSE)
df_gemini1.0_pro_t1.0     <- read_model("gemini-1.0-pro-002_t1-0",      icd = FALSE)
df_gemini1.5_pro_t1.0     <- read_model("gemini-1.5-pro-001_t1-0",      icd = FALSE)

# ICD-converted responses for the same six models (icd = TRUE)
df_gpt3.5_icd             <- read_model("gpt-3.5-turbo-1106",          icd = TRUE)
df_gpt4.0_icd             <- read_model("gpt-4-turbo-preview",         icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd  <- read_model("claude-3-opus-20240229_t1-0",  icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0",      icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0",      icd = TRUE)

Rank abundance

Original responses

# One rank-abundance plot per model (original responses)
rank_abundance_plot(df_gpt3.5) +
  ggtitle("ChatGPT 3.5")

rank_abundance_plot(df_gpt4.0) +
  ggtitle("ChatGPT 4.0")

rank_abundance_plot(df_claude3_haiku_t1.0) +
  ggtitle("Claude3 Haiku t1.0")

rank_abundance_plot(df_claude3_opus_t1.0) +
  ggtitle("Claude3 Opus")

rank_abundance_plot(df_gemini1.0_pro_t1.0) +
  ggtitle("Gemini 1.0 Pro")

rank_abundance_plot(df_gemini1.5_pro_t1.0) +
  ggtitle("Gemini 1.5 Pro")

ICD converted responses

# One rank-abundance plot per model (ICD-converted responses)
rank_abundance_plot(df_gpt3.5_icd) +
  ggtitle("ChatGPT 3.5 ICD")

rank_abundance_plot(df_gpt4.0_icd) +
  ggtitle("ChatGPT 4.0 ICD")

rank_abundance_plot(df_claude3_haiku_t1.0_icd) +
  ggtitle("Claude3 Haiku ICD")

rank_abundance_plot(df_claude3_opus_t1.0_icd) +
  ggtitle("Claude3 Opus ICD")

rank_abundance_plot(df_gemini1.0_pro_t1.0_icd) +
  ggtitle("Gemini 1.0 Pro ICD")

rank_abundance_plot(df_gemini1.5_pro_t1.0_icd) +
  ggtitle("Gemini 1.5 Pro ICD")

Combined model data

# All six models on a single rank-abundance plot (original responses)
multi_ranked_abundance_plot(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0
) +
  ggtitle("Combined model rank abundance", "Original responses")
Warning: The `fun.y` argument of `stat_summary()` is deprecated as of ggplot2 3.3.0.
Please use the `fun` argument instead.Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
Please use `linewidth` instead.

# All six models on a single rank-abundance plot (ICD-converted responses)
multi_ranked_abundance_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  ggtitle("Combined model rank abundance", "ICD converted responses")

Top diagnoses plots

# Format axis labels for display. Everything from the first "___" to the end
# of the label is removed; presumably this is the suffix that
# tidytext::reorder_within() appends (the labeler is used with
# tidytext::scale_x_reordered below). The result is then wrapped to
# `wrap_width` characters per line.
custom_labeler <- function(x, wrap_width = 33) {
  without_suffix <- str_replace(x, "___.+$", "")
  str_wrap(without_suffix, width = wrap_width)
}

# Extra layers for the single-model ICD top-diagnosis plots:
# - smaller text with tighter line height;
# - a reordered x scale whose labels are cleaned and wrapped at 45 characters.
custom_text_formatting <- list(
  theme(
    axis.text = element_text(size = 7, lineheight = 0.7),
    strip.text = element_text(size = 7),
    axis.title = element_text(size = 9)
  ),
  tidytext::scale_x_reordered(
    labels = function(lbl) custom_labeler(lbl, wrap_width = 45)
  )
)
# Top-diagnosis plot per model (original responses)
n_diag <- 25
sub <- "Original responses"

top_diagnosis_plot(df_gpt3.5, n_diag = n_diag) +
  ggtitle("ChatGPT 3.5", sub)

top_diagnosis_plot(df_gpt4.0, n_diag = n_diag) +
  ggtitle("ChatGPT 4.0", sub)

top_diagnosis_plot(df_claude3_haiku_t1.0, n_diag = n_diag) +
  ggtitle("Claude3 Haiku t1.0", sub)

top_diagnosis_plot(df_claude3_opus_t1.0, n_diag = n_diag) +
  ggtitle("Claude3 Opus t1.0", sub)

top_diagnosis_plot(df_gemini1.0_pro_t1.0, n_diag = n_diag) +
  ggtitle("Gemini 1.0 Pro", sub)

top_diagnosis_plot(df_gemini1.5_pro_t1.0, n_diag = n_diag) +
  ggtitle("Gemini 1.5 Pro", sub)

# Top-diagnosis plots per model (ICD-converted responses); custom_text_formatting
# shrinks the text and wraps the labels at 45 characters.
n_diag <- 25
sub <- "ICD converted responses"
top_diagnosis_plot(df_gpt3.5_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle("ChatGPT 3.5 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_gpt4.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle("ChatGPT 4.0 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_claude3_haiku_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle("Claude3 Haiku t1.0 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_claude3_opus_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle("Claude3 Opus t1.0 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_gemini1.0_pro_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle("Gemini 1.0 Pro ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_gemini1.5_pro_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle("Gemini 1.5 Pro ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

# Top 25 diagnoses across all models (original responses)
multi_top_diagnosis_plot(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
)

# Top 15 diagnoses across all models (ICD-converted responses). The plot is
# kept so its underlying data can be summarised below.
plt_diag_icd <- multi_top_diagnosis_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 33,
  n_diag = 15
) +
  guides(size = guide_legend(override.aes = list(size = 2)))

plt_diag_icd

# Mean `freq` of every (criteria, diagnosis) pair, most frequent first
plt_diag_icd$data %>%
  summarise(freq = mean(freq), .by = c("criteria", "diagnosis")) %>%
  arrange(criteria, desc(freq))

Cumulative top frequency plots

# Cumulative top-frequency plot per model. cumulative_frequency_plot()
# returns a list whose `$plot` element is the ggplot that gets a title here.
# Titles fixed: the Opus plots were labelled "Claude3 Haiku", the Gemini 1.5
# ICD plot was labelled "Gemini Pro 1.0 ICD", and gpt-3.5 was labelled "GPT3".
sub <- "Original responses"
cumulative_frequency_plot(df_gpt3.5)$plot+ggtitle("GPT3.5", sub)

cumulative_frequency_plot(df_gpt4.0)$plot+ggtitle("GPT4", sub)

cumulative_frequency_plot(df_claude3_haiku_t1.0)$plot+ggtitle("Claude3 Haiku", sub)

cumulative_frequency_plot(df_claude3_opus_t1.0)$plot+ggtitle("Claude3 Opus", sub)

cumulative_frequency_plot(df_gemini1.0_pro_t1.0)$plot+ggtitle("Gemini Pro 1.0", sub)

cumulative_frequency_plot(df_gemini1.5_pro_t1.0)$plot+ggtitle("Gemini Pro 1.5", sub)

sub <- "ICD converted responses"
cumulative_frequency_plot(df_gpt3.5_icd)$plot+ggtitle("GPT3.5 ICD", sub)

cumulative_frequency_plot(df_gpt4.0_icd)$plot+ggtitle("GPT4 ICD", sub)

cumulative_frequency_plot(df_claude3_haiku_t1.0_icd)$plot+ggtitle("Claude3 Haiku ICD", sub)

cumulative_frequency_plot(df_claude3_opus_t1.0_icd)$plot+ggtitle("Claude3 Opus ICD", sub)

cumulative_frequency_plot(df_gemini1.0_pro_t1.0_icd)$plot+ggtitle("Gemini Pro 1.0 ICD", sub)

cumulative_frequency_plot(df_gemini1.5_pro_t1.0_icd)$plot+ggtitle("Gemini Pro 1.5 ICD", sub)

# Combined frequency of the top 25 diagnoses, all models (original responses)
plt_freq <- multi_cumulative_frequency_plot(
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0,
  n_diagnoses = 25,
  distribution_vis = "points"
) +
  ggtitle("Original responses")

plt_freq

plt_freq$data %>%
  summarise(freq = mean(total_frequency), .by = "criteria")

# Same summary for the ICD-converted responses
plt_freq_icd <- multi_cumulative_frequency_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd,
  n_diagnoses = 25,
  distribution_vis = "points"
) +
  ggtitle("ICD converted responses")

plt_freq_icd

plt_freq_icd$data %>%
  summarise(freq = mean(total_frequency), .by = "criteria")

Diagnosis rank table

# Rank of mast-cell / anaphylaxis diagnoses for each model; diagnosis text is
# truncated to 60 characters for display.
diagnosis_rank_table(df_gpt3.5, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_haiku_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_opus_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.0_pro_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.5_pro_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt3.5_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_haiku_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_opus_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.0_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.5_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))

# Cross-model rank table for the listed ICD-10 codes (ICD-converted data)
rank_table <- multi_diagnosis_rank_table(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 "
)
rank_table
rank_table %>%
  flextable() %>%
  width(width = 30) %>%
  align(j = 2:3, align = "center", part = "all")

Diagnosis

MCAS - Consortium

MCAS - Alternative

T78.2 Anaphylactic shock, unspecified

1
[1, 1, 1, 1, 1, 1]

132
[216, 87, 99, 174, 141, 77]

D47.02 Systemic mastocytosis

9
[22, 8, 7, 2, 14, 2]

50
[92, 78, 25, 41, 46, 19]

D89.41 Monoclonal mast cell activation syndrome

70
[128, 22, 28, 11, 179, 51]

74
[141, 64, 22, 37, 168, 12]

D89.49 Other mast cell activation disorder

234
[308, 62, 101, 140, 478, 318]

625
[1155, 568, 178, 467, 1109, 275]

D89.4 Mast cell activation syndrome and related disorders

496
[726, 174, NA, NA, 605, 478]

1423
[NA, 833, 1850, 1594, 1906, 933]

Diversity

# Shannon diversity per criteria, all models (original responses)
multi_shannon_plot(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
)

# Same for the ICD-converted responses; the plot is kept for the summaries below
plt_div_icd <- multi_shannon_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
)

plt_div_icd

plt_div_icd$data %>%
  summarise(shannon = mean(shannon), .by = "criteria")
extract_ggpubr_pvalues(plt_div_icd)

Similarity

# Per-model Bray-Curtis similarity heatmaps: original responses first,
# then the ICD-converted ones.
diagnosis_similarity_heatmap(
  df_gpt3.5,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gpt4.0,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_claude3_haiku_t1.0,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_claude3_opus_t1.0,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gemini1.0_pro_t1.0,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gemini1.5_pro_t1.0,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gpt3.5_icd,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gpt4.0_icd,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_claude3_haiku_t1.0_icd,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_claude3_opus_t1.0_icd,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gemini1.0_pro_t1.0_icd,
  method = "bray"
)

diagnosis_similarity_heatmap(
  df_gemini1.5_pro_t1.0_icd,
  method = "bray"
)

# Bray-Curtis similarity heatmaps for all models side by side, without
# dendrograms: original responses, then ICD-converted responses.
# (Uses FALSE rather than F, which can be reassigned.)
multi_diagnosis_similarity_heatmap(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0,
  method = "bray",
  show_dend = FALSE,
  label_size = 6,
  title_size = 9
)

multi_diagnosis_similarity_heatmap(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  method = "bray",
  show_dend = FALSE,
  label_size = 6,
  title_size = 9
)

  • Bray-Curtis similarity compares the sets of alternative diagnoses, including how often each diagnosis occurs, generated under each pair of diagnostic criteria.
  • This demonstrates that the SLE criteria result in a very similar set and frequency of diagnoses, while the diagnoses associated with the two MCAS criteria are as different from each other as they are from those generated by the criteria of other conditions.

Multi Bray Curtis Similarity

# Pairwise Bray-Curtis similarity between the diagnosis-count profiles of
# every (criteria, model) combination.
#
# @param df Long data frame with columns `original_df` (source model),
#   `criteria` and `diagnosis`; presumably one row per generated diagnosis
#   as returned by combine_data_frames() (TODO confirm).
# @return Square matrix of similarities (1 - Bray-Curtis dissimilarity),
#   with row/column names of the form "<criteria>_<model>".
calculate_similarity <- function(df){
  df <- df %>% 
    rename(model = original_df) %>% 
    count(model, criteria, diagnosis) %>% 
    unite(criteria, criteria, model)
  
  # Weight the cross-tab by the counts in `n`. After count(), each
  # (criteria, diagnosis) pair is a single row, so table() would reduce the
  # data to 0/1 presence/absence and drop the diagnosis frequencies that
  # Bray-Curtis is meant to compare.
  xtabs(n ~ criteria + diagnosis, data = df) %>% 
    vegan::vegdist(method = "bray") %>% 
    as.matrix() %>% 
    {1 - .}
}
# Model x criteria Bray-Curtis similarity heatmaps: ICD-converted responses
# first, then original responses. (FALSE instead of the reassignable F.)
combine_data_frames(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) %>% 
  calculate_similarity() %>% 
  model_criteria_heatmap(., 
                color_scale = viridis::viridis(3), 
                title = " ", 
                metric = "Bray-Curtis\nsimilarity", 
                symmetric = FALSE,
                font_size = 8) 

combine_data_frames(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0
) %>% 
  calculate_similarity() %>% 
  model_criteria_heatmap(., 
                color_scale = viridis::viridis(3), 
                title = " ", 
                metric = "Bray-Curtis\nsimilarity", 
                symmetric = FALSE,
                font_size = 8) 

PCA

# Per-model PCA plots (original responses). Titles now identify the model
# version: both Gemini plots were previously titled just "Gemini", and
# gpt-3.5 was titled "GPT3".
diagnosis_pca_plot(df_gpt3.5) + ggtitle("GPT3.5")

diagnosis_pca_plot(df_gpt4.0) + ggtitle("GPT4")

diagnosis_pca_plot(df_claude3_haiku_t1.0) + ggtitle("Claude Haiku")

diagnosis_pca_plot(df_claude3_opus_t1.0) + ggtitle("Claude Opus")

diagnosis_pca_plot(df_gemini1.0_pro_t1.0) + ggtitle("Gemini 1.0 Pro")

diagnosis_pca_plot(df_gemini1.5_pro_t1.0) + ggtitle("Gemini 1.5 Pro")

# Joint PCA of the ICD-converted responses. Each row is a "model__criteria"
# combination and each column holds the count of one diagnosis (0 when absent).
# `df` holds the prcomp() result.
df <- listN(df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd) %>% 
  mapply(function(x, y) {mutate(x, model = y)}, ., names(.), SIMPLIFY = FALSE) %>% 
  bind_rows() %>% 
  count(model, criteria, diagnosis) %>% 
  pivot_wider(names_from = "diagnosis", values_from = "n", values_fill = 0) %>% 
  unite(id, model, criteria, sep = "__") %>% 
  column_to_rownames("id") %>% 
  prcomp(scale. = FALSE)

# Scores on the first two PCs, coloured by criteria
as.data.frame(df$x) %>% 
  rownames_to_column("id") %>% 
  separate(id, into = c("model", "criteria"), sep = "__") %>% 
  format_criteria() %>% 
  format_models() %>% 
  ggplot(aes(x = PC1, y = PC2, color = criteria)) +
  geom_point() +
  # ggrepel::geom_label_repel() +
  theme_bw() +
  scale_color_brewer(palette = "Dark2")

Precision

  • Precision represents how similar each iteration of a 10-point differential diagnosis is with all other differential diagnoses from the same set of criteria.
  • I.e., how reproducible the 10-point differential diagnosis is for each criterion
  • Measured by obtaining the Bray-Curtis similarity values between all iterations within a criterion
# Script for calculating all Bray-Curtis similarity values within a criterion
# Found in source(here("scripts/diversity_analysis/calculate_precision.R"))
# Calculate precision
library(here)
source(here("utils/data_processing.R"))

# Model identifiers parsed from "diagnoses_<model>[_icd].csv.gz" file names.
# The dot in ".csv" is escaped, and vapply() keeps the result a character
# vector even when no files are found.
models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>% 
  str_split("diagnoses_|_icd|\\.csv") %>% 
  vapply(function(x) x[2], character(1)) %>% 
  unique()

# TRUE -> read the "diagnoses_<model>_icd.csv.gz" files instead of the originals
use_icd <- TRUE

if (use_icd) {
  models <- str_glue("{models}_icd")
}

# For each model: read its diagnoses, compute precision with
# calculate_precision(), and write the result to data/diversity_analysis/.
for (m in models) {
  print(sprintf("READING IN DATA FOR: %s", m))
  read_path <- sprintf("data/processed_diagnoses/diagnoses_%s.csv.gz", m)
  df <- read_csv(here(read_path))
  
  print(sprintf("CALCULATING PRECISION FOR: %s", m))
  df <- calculate_precision(df)
  
  print(sprintf("WRITING PRECISION DATA FOR: %s", m))
  out_path <- sprintf("data/diversity_analysis/diagnosis_precision_%s.csv.gz", m)
  write_csv(df, here(out_path))
}
# Convert precision summaries from Bray-Curtis distance to similarity
# (similarity = 1 - distance).
#
# The transform is decreasing, so the extremes swap: the largest similarity
# comes from the smallest distance, and vice versa. Both new extremes are
# computed from the ORIGINAL min/max. The previous mutate() version evaluated
# `min = 1 - max` after `max` had already been overwritten, so `min` came
# back unchanged (still the distance minimum).
#
# @param df Data frame with numeric distance columns `mean`, `min` and `max`.
#   Other columns (e.g. n, sd, se) are returned untouched; sd and se are
#   invariant under x -> 1 - x.
# @return `df` with `mean`, `min` and `max` expressed as similarities, in the
#   same column order.
precision_dist_to_sim <- function(df) {
  dist_min <- df$min
  dist_max <- df$max
  df$mean <- 1 - df$mean
  df$max <- 1 - dist_min
  df$min <- 1 - dist_max
  df
}

# Precision of the ICD-converted responses: within-criteria Bray-Curtis
# distances from the compiled precision file, converted to similarities.
# Gemini 1.5 Flash is excluded. (TRUE instead of the reassignable T.)
plt_precision_icd <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>% 
  precision_dist_to_sim() %>% 
  format_criteria() %>% 
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>% 
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Average Bray-Curtis Similarity") +
  # Pairwise Wilcoxon tests between criteria with BH-adjusted p-values;
  # non-significant comparisons are hidden.
  ggpubr::geom_pwc(
    aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH",
    hide.ns = TRUE, label = "p.adj.signif", bracket.nudge.y = 0.3,
    vjust = 0.6, step.increase = 0.14, tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points")
Rows: 42 Columns: 8── Column specification ───────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (2): criteria, model
dbl (6): n, mean, max, min, sd, se
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.
plt_precision_icd

# Mean similarity per criteria, plus the p-values from geom_pwc()
plt_precision_icd$data %>%
  summarise(mean = mean(mean), .by = "criteria")
extract_ggpubr_pvalues(plt_precision_icd)

iNEXT

# Print one iNEXT::ggiNEXT() plot per requested `type` for an iNEXT result.
#
# @param inext_obj Object accepted by iNEXT::ggiNEXT() (here, read from RDS).
# @param types Values passed to ggiNEXT()'s `type` argument; defaults to
#   1:3, i.e. all three plot types, as before.
# @return `inext_obj`, invisibly, so the function can sit in a pipe.
#
# The previous version added scale_color_brewer("Set1") and then
# scale_color_brewer("Dark2"). The second always replaced the first, so only
# the Dark2 scale is kept.
inext_plots <- function(inext_obj, types = 1:3) {
  for (i in types) {
    plt <- iNEXT::ggiNEXT(inext_obj, type = i, facet.var = "Assemblage", color.var = "Assemblage") +
      theme_classic() +
      theme(axis.text.x = element_text(angle = 90)) +
      scale_color_brewer(palette = "Dark2")
    print(plt)
  }
  invisible(inext_obj)
}

readRDS(here("data/diversity_analysis/mcas_iNEXT_gpt4_e250000.RDS")) %>%
  inext_plots()
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.

readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_gpt4_e200000.RDS")) %>%
  inext_plots()
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.

readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_psuedoMinus_gpt4_e200000.RDS")) %>%
  inext_plots()
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.

# custom_labeler <- function(x, wrap_width=33) {
#     x %>%
#         str_replace("___.+$", "") %>%
#         str_wrap(width = wrap_width)
# }

Final plot

Version 1

# Version 1 figure settings ----
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Shared text and legend formatting applied to every panel.
# ggplot2 >= 3.4.0 sets rectangle border thickness with `linewidth`; `size`
# is deprecated (it triggered the warning rendered below this block).
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, 'cm'),
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, 'cm')  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, 'cm'),
  # legend.box.spacing = unit(0, "cm")
))
Warning: The `size` argument of `element_rect()` is deprecated as of ggplot2 3.4.0.
Please use the `linewidth` argument instead.
# Tight facet strips: the same `strip_margin` on all four sides of the
# strip text, for both x and y strips.
strip_margin <- 1
strip_formatting <- local({
  m <- margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  list(theme(
    strip.text.x = element_text(margin = m),
    strip.text.y = element_text(margin = m)
  ))
})

# Panel A: top `n_diagnoses_bar` ICD diagnoses across all models
plt_diags <- multi_top_diagnosis_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 58,
  n_diag = n_diagnoses_bar
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  # Slightly larger y labels than the shared formatting; no gap between facets
  theme(axis.text.y = element_text(size = 6.5), panel.spacing = unit(0, "lines")) +
  strip_formatting +
  # Enlarge the points drawn in the legend
  guides(color = guide_legend(override.aes = list(size = 2)))
  

# Panel B: rank abundance of the ICD-converted responses, two-column legend below
plt_rank <- multi_ranked_abundance_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.direction = "horizontal", legend.position = "bottom") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2))

# Panel C: combined frequency of the top `n_diagnoses_cumulative` diagnoses.
# The y label is built from the same variable, so it stays correct if the
# count changes (previously "50" was hard-coded).
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"), x = NULL)

# Panel D: Shannon diversity of the ICD-converted responses
plt_shannon <- multi_shannon_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
) +
  apply_text_formatting +
  theme(legend.direction = "horizontal", legend.position = "bottom") +
  guides(color = guide_legend(ncol = 2))

# Panel E: precision (within-criteria Bray-Curtis similarity) of the
# ICD-converted responses, with BH-adjusted pairwise Wilcoxon tests between
# criteria. (TRUE instead of the reassignable T.)
plt_precision <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>% 
  precision_dist_to_sim() %>% 
  format_criteria() %>% 
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>% 
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Average Bray-Curtis\nSimilarity") +
  ggpubr::geom_pwc(
    aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH",
    hide.ns = TRUE, label = "p.adj.signif", bracket.nudge.y = 0.3,
    vjust = 0.6, step.increase = 0.14, tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points") +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Version 1 layout: panel A on top, a thin spacer, then panels B-E in one row
bottom_row_v1 <- NULL
full_plt <- plot_grid(
  plt_diags,
  NULL,
  plot_grid(
    plt_rank, plt_cumulative, plt_shannon, plt_precision,
    nrow = 1,
    axis = "tb",
    align = "h",
    rel_widths = c(1, 0.7, 0.7, 0.7),
    labels = LETTERS[2:5],
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)
rm(bottom_row_v1)

full_plt

Version 2

# Version 2 figure settings ----
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Shared text and legend formatting. Uses `linewidth` for the legend box
# border: `size` in element_rect() is deprecated since ggplot2 3.4.0.
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, 'cm'),
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, 'cm')  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, 'cm'),
  # legend.box.spacing = unit(0, "cm")
))

# Tight facet strips: the same `strip_margin` on all four sides of the strip text
strip_margin <- 1
strip_formatting <- local({
  m <- margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  list(theme(
    strip.text.x = element_text(margin = m),
    strip.text.y = element_text(margin = m)
  ))
})

# Panel A: top `n_diagnoses_bar` ICD diagnoses; legend kept on a single row
plt_diags <- multi_top_diagnosis_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 58,
  n_diag = n_diagnoses_bar
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5), panel.spacing = unit(0, "lines")) +
  strip_formatting +
  # Larger legend points, one legend row
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))
  

# Panel B: rank abundance with the legend drawn inside the panel.
# ggplot2 >= 3.5.0 deprecates a numeric `legend.position` (see the warning
# rendered below); the inside placement now goes through
# `legend.position.inside`.
plt_rank <-
  multi_ranked_abundance_plot(
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "inside", legend.position.inside = c(0.7, 0.7)) +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 1)) +
  labs(color = NULL)
Warning: A numeric `legend.position` argument in `theme()` was deprecated in ggplot2 3.5.0.
Please use the `legend.position.inside` argument of `theme()` instead.
# Combined frequency of the top `n_diagnoses_cumulative` diagnoses. The y
# label is derived from that variable instead of a hard-coded "50".
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"), x = NULL)

# Shannon diversity of the ICD-converted responses
plt_shannon <- multi_shannon_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
) +
  apply_text_formatting +
  theme(legend.direction = "horizontal", legend.position = "bottom") +
  guides(color = guide_legend(ncol = 2))

# Precision (within-criteria Bray-Curtis similarity) of the ICD-converted
# responses, with BH-adjusted pairwise Wilcoxon tests between criteria.
# (TRUE instead of the reassignable T.)
plt_precision <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>% 
  precision_dist_to_sim() %>% 
  format_criteria() %>% 
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>% 
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Average Bray-Curtis\nSimilarity") +
  ggpubr::geom_pwc(
    aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH",
    hide.ns = TRUE, label = "p.adj.signif", bracket.nudge.y = 0.3,
    vjust = 0.6, step.increase = 0.14, tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points") +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Version 2 layout: panel A on top; below it, the rank-abundance plot next
# to the Shannon and precision panels, which share one legend.
full_plt <- plot_grid(
  ###
  plt_diags,
  ###
  NULL,
  plot_grid(
    plt_rank,
    plot_grid(
      plot_grid(
        plt_shannon + theme(legend.position = "none"),
        plt_precision + theme(legend.position = "none"),
        nrow = 1,
        axis = 'tb',
        align = 'h'
      ),
      # guide_legend() takes `nrow`; the previous `row = 1` is not a
      # guide_legend() argument and so had no effect.
      get_legend(plt_shannon + guides(color = guide_legend(nrow = 1))),
      ncol = 1,
      rel_heights = c(1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(1, 1),
    # labels = c(LETTERS[2:5]),
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)
Warning: Multiple components found; returning the first one. To return all, use `return_all = TRUE`.
# Render the Version 2 composite figure
full_plt

Version 3

# Version 3 figure settings ----
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Shared text and legend formatting. Uses `linewidth` for the legend box
# border: `size` in element_rect() is deprecated since ggplot2 3.4.0.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(size = label_size, margin = margin(0, 0, 0, r = 1)),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, 'cm'),
    legend.key.width = unit(0.4, 'cm'),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    # Negative spacing pulls the legend rows closer together
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(5, "pt")
  )
)

# Tight facet strips: the same `strip_margin` on all four sides of the strip text
strip_margin <- 1
strip_formatting <- local({
  m <- margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  list(theme(
    strip.text.x = element_text(margin = m),
    strip.text.y = element_text(margin = m)
  ))
})

# Panel A: top `n_diagnoses_bar` ICD diagnoses; legend kept on a single row
plt_diags <- multi_top_diagnosis_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 58,
  n_diag = n_diagnoses_bar
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5), panel.spacing = unit(0, "lines")) +
  strip_formatting +
  # Larger legend points, one legend row
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))
  

# Rank abundance with the legend drawn inside the panel. ggplot2 >= 3.5.0
# deprecates a numeric `legend.position`; the inside placement goes through
# `legend.position.inside`.
plt_rank <-
  multi_ranked_abundance_plot(
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "inside", legend.position.inside = c(0.7, 0.7)) +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 1)) +
  labs(color = NULL)

# Combined frequency of the top `n_diagnoses_cumulative` diagnoses. The y
# label is derived from that variable instead of a hard-coded "50".
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"), x = NULL)

# Shannon diversity of the ICD-converted responses, with vertical x labels
plt_shannon <- multi_shannon_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
) +
  apply_text_formatting +
  theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
  guides(color = guide_legend(ncol = 2))

# Precision panel: per-criteria mean Bray-Curtis similarity read from the
# compiled ICD precision file (Gemini 1.5 Flash excluded), with BH-adjusted
# pairwise Wilcoxon comparisons drawn by ggpubr::geom_pwc().
plt_precision <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  list(
    theme_bw(),
    theme(axis.text.x = element_text(angle = 90, hjust = 1)),
    labs(x = "", y = "Mean Bray-Curtis Similarity"),
    ggpubr::geom_pwc(
      aes(group = criteria),
      method = "wilcox.test", p.adjust.method = "BH", hide.ns = TRUE,
      label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6,
      step.increase = 0.14, tip.length = 0.02
    ),
    labs(color = NULL),
    scale_color_brewer(palette = "Dark2"),
    plot_selector("points"),
    apply_text_formatting,
    theme(legend.position = "bottom", legend.direction = "horizontal"),
    guides(color = guide_legend(ncol = 2))
  )

# Bray-Curtis similarity heatmap for the ICD-converted responses. The result
# is a ComplexHeatmap object (rendered later via ComplexHeatmap::draw()).
plt_similarity <- multi_diagnosis_similarity_heatmap(
  method = "bray",
  show_dend = TRUE,
  dendrogram_weight = unit(2.5, "mm"),
  legend_label = "Bray-Curtis similarity",
  legend_direction = "horizontal",
  label_size = 6,
  title_size = 9,
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
)

# Assemble the full figure: the top-diagnosis panel, a thin spacer row, then a
# bottom row of [spacer | similarity heatmap | rank-abundance |
# Shannon + precision with a shared legend]. rel_widths/rel_heights are
# positional and must stay aligned with the children listed.
full_plt <- plot_grid(
  ###
  plt_diags,
  ###
  NULL,  # spacer between the top and bottom rows
  plot_grid(
    NULL,  # left margin of the bottom row
    plot_grid(
    # ComplexHeatmap output is not a ggplot; capture its drawing as a grob.
    grid::grid.grabExpr(ComplexHeatmap::draw(plt_similarity, heatmap_legend_side = 'bottom')),
    NULL,ncol=1,rel_heights = c(1,0.1)
    ),
    plot_grid(
    plt_rank,
    NULL,ncol=1,rel_heights = c(1,0.1)
    ),
    plot_grid(
      # Shannon and precision side by side with their own legends removed...
      plot_grid(
        plt_shannon+ theme(legend.position="none"),
        plt_precision+ theme(legend.position="none"),
        nrow = 1,
        axis = 'tb',
        align = 'h'
      ),
      # ...and the Shannon legend reused once underneath both.
      plot_grid(NULL,get_legend(plt_shannon),nrow=1,rel_widths=c(0.2,1)),
      NULL,
      ncol = 1,
      rel_heights = c(1,0.05,0.1)
    ),
    nrow = 1, 
    rel_widths = c(0.1,0.8, 1,0.9),
    vjust = 0.2
    ),
  ncol = 1,
  rel_heights = c(1.2, 0.07, 0.65)
)  

# Overlay the panel letters A-E at fixed figure coordinates, then show it.
full_plt <- cowplot::ggdraw(full_plt) +
  cowplot::draw_plot_label(
    label = LETTERS[1:5],
    x = c(0, 0, 0.35, 0.67, 0.83),
    y = c(1, 0.38, 0.38, 0.38, 0.38)
  )
full_plt

  • TODO: add a two-column legend under panels D and E.
# Write the assembled figure to PDF at its final print size (inches).
ggsave(
  filename = here("figures/3_diagnosis_diversity.pdf"),
  plot = full_plt,
  width = 7.5,
  height = 8.5
)

# PDF table option: zero column separation (tabcolsep).
# NOTE(review): set_table_properties() is called here without a flextable
# as its first argument; confirm this call does what is intended (a global
# default may belong in set_flextable_defaults() instead).
set_table_properties(opts_pdf = list(tabcolsep = 0))

# Set flextable's `fonts_ignore` default to TRUE for subsequent tables.
set_flextable_defaults(fonts_ignore=TRUE)

# Rank table for the selected anaphylaxis / mast-cell ICD codes across the six
# models, formatted as a compact flextable. The final `{print(...); .}` step
# previews the table as PDF and still returns it so the chunk also displays it.
multi_diagnosis_rank_table(search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
                                         df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd) %>% 
  flextable() %>% 
  width(width = 2) %>% 
  fontsize(size = 9) %>% 
  fontsize(size = 10, part = "header") %>% 
  padding(padding = 0) %>% 
  align(j = 2:3, align = "center", part = "all") %>% 
  set_table_properties(opts_pdf = list(arraystretch = 1.25)) %>% 
  {print(., preview = "pdf");.}

Diagnosis

MCAS - Consortium

MCAS - Alternative

T78.2 Anaphylactic shock, unspecified

1
[1, 1, 1, 1, 1, 1]

132
[216, 87, 99, 174, 141, 77]

D47.02 Systemic mastocytosis

9
[22, 8, 7, 2, 14, 2]

50
[92, 78, 25, 41, 46, 19]

D89.41 Monoclonal mast cell activation syndrome

70
[128, 22, 28, 11, 179, 51]

74
[141, 64, 22, 37, 168, 12]

D89.49 Other mast cell activation disorder

234
[308, 62, 101, 140, 478, 318]

625
[1155, 568, 178, 467, 1109, 275]

D89.4 Mast cell activation syndrome and related disorders

496
[726, 174, NA, NA, 605, 478]

1423
[NA, 833, 1850, 1594, 1906, 933]

---
title: "Diagnosis distribution analysis"
output: 
  html_notebook:
    toc: true
    toc_float: true
---

```{r, message = F}
library(here)
library(cowplot)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))
```

```{r}
# Derive the model identifiers from the processed diagnosis file names
# ("diagnoses_<model>.csv.gz" and "diagnoses_<model>_icd.csv.gz"): the model
# name is the piece between the "diagnoses_" prefix and "_icd" / ".csv".
# vapply() guarantees a character vector (character(0) when no files match);
# sapply() would return an empty list in that case.
all_models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>% 
  str_split("diagnoses_|_icd|\\.csv") %>% 
  vapply(function(parts) parts[2], character(1)) %>% 
  unique()
all_models
```
# Import data

```{r, message = F}
# Load the original (icd = FALSE) diagnoses for each of the six models.
# read_model() is presumably provided by the sourced utils/data_processing.R.
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
```

```{r, message = F}
# Load the ICD-converted (icd = TRUE) diagnoses for the same six models.
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
```


# Rank abundance

**Original responses**
```{r}
# One rank-abundance plot per model (original responses); each top-level
# expression is auto-printed by the notebook.
rank_abundance_plot(df_gpt3.5)+ggtitle("ChatGPT 3.5")
rank_abundance_plot(df_gpt4.0)+ggtitle("ChatGPT 4.0")
rank_abundance_plot(df_claude3_haiku_t1.0)+ggtitle("Claude3 Haiku t1.0")
rank_abundance_plot(df_claude3_opus_t1.0)+ggtitle("Claude3 Opus")
rank_abundance_plot(df_gemini1.0_pro_t1.0)+ggtitle("Gemini 1.0 Pro")
rank_abundance_plot(df_gemini1.5_pro_t1.0)+ggtitle("Gemini 1.5 Pro")
```

**ICD converted responses**
```{r}
# One rank-abundance plot per model (ICD-converted responses).
rank_abundance_plot(df_gpt3.5_icd)+ggtitle("ChatGPT 3.5 ICD")
rank_abundance_plot(df_gpt4.0_icd)+ggtitle("ChatGPT 4.0 ICD")
rank_abundance_plot(df_claude3_haiku_t1.0_icd)+ggtitle("Claude3 Haiku ICD")
rank_abundance_plot(df_claude3_opus_t1.0_icd)+ggtitle("Claude3 Opus ICD")
rank_abundance_plot(df_gemini1.0_pro_t1.0_icd)+ggtitle("Gemini 1.0 Pro ICD")
rank_abundance_plot(df_gemini1.5_pro_t1.0_icd)+ggtitle("Gemini 1.5 Pro ICD")
```

**Combined model data**

```{r}
# Rank-abundance curves of all six models overlaid (original responses).
multi_ranked_abundance_plot(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0
) +
  ggtitle(label = "Combined model rank abundance", subtitle = "Original responses")
```

```{r}
# Rank-abundance curves of all six models overlaid (ICD-converted responses).
multi_ranked_abundance_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  ggtitle(label = "Combined model rank abundance", subtitle = "ICD converted responses")
```

# Top diagnoses plots

```{r, fig.width = 12, fig.height = 8}
# Axis-label formatter for reordered facets: drop everything from the first
# "___" onward (the within-facet suffix, presumably added by
# tidytext::reorder_within) and wrap the rest to `wrap_width` characters.
custom_labeler <- function(x, wrap_width = 33) {
  stripped <- str_replace(x, "___.+$", "")
  str_wrap(stripped, width = wrap_width)
}

# Smaller text plus the reordered x scale using the labeler at width 45.
custom_text_formatting <- list(
  theme(
    axis.text = element_text(size = 7, lineheight = 0.7),
    strip.text = element_text(size = 7),
    axis.title = element_text(size = 9)
  ),
  tidytext::scale_x_reordered(
    labels = function(x) custom_labeler(x, wrap_width = 45)
  )
)
```

```{r, fig.width = 16, fig.height = 8}
# Per-model top-`n_diag` diagnosis plots for the original responses; each
# top-level expression is auto-printed by the notebook.
n_diag <- 25
sub <- "Original responses"
top_diagnosis_plot(df_gpt3.5, n_diag = n_diag)+ggtitle("ChatGPT 3.5", sub)
top_diagnosis_plot(df_gpt4.0, n_diag = n_diag)+ggtitle("ChatGPT 4.0", sub)
top_diagnosis_plot(df_claude3_haiku_t1.0, n_diag = n_diag)+ggtitle("Claude3 Haiku t1.0", sub)
top_diagnosis_plot(df_claude3_opus_t1.0, n_diag = n_diag)+ggtitle("Claude3 Opus t1.0", sub)
top_diagnosis_plot(df_gemini1.0_pro_t1.0, n_diag = n_diag)+ggtitle("Gemini 1.0 Pro", sub)
top_diagnosis_plot(df_gemini1.5_pro_t1.0, n_diag = n_diag)+ggtitle("Gemini 1.5 Pro", sub)
```

```{r, fig.width = 16, fig.height = 10}
# Per-model top-`n_diag` diagnosis plots for the ICD-converted responses,
# using custom_text_formatting for the long ICD labels.
n_diag <- 25
sub <- "ICD converted responses"
top_diagnosis_plot(df_gpt3.5_icd, n_diag = n_diag) + custom_text_formatting + ggtitle("ChatGPT 3.5 ICD", sub) 
top_diagnosis_plot(df_gpt4.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("ChatGPT 4.0 ICD", sub)
top_diagnosis_plot(df_claude3_haiku_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Claude3 Haiku t1.0 ICD", sub)
top_diagnosis_plot(df_claude3_opus_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Claude3 Opus t1.0 ICD", sub)
top_diagnosis_plot(df_gemini1.0_pro_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Gemini 1.0 Pro ICD", sub)
top_diagnosis_plot(df_gemini1.5_pro_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Gemini 1.5 Pro ICD", sub)
```

```{r, fig.width = 16, fig.height = 10}
# Combined top-25 diagnosis plot across all models (original responses).
multi_top_diagnosis_plot(
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25,
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0
)
```



```{r, fig.width = 16, fig.height = 10}
# Combined top-15 diagnosis plot (ICD-converted), then the mean frequency of
# each diagnosis per criteria taken from the plot's data, highest first.
plt_diag_icd <- multi_top_diagnosis_plot(
  distribution_vis = "points", wrap_width = 33, n_diag = 15,
  df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) + guides(size = guide_legend(override.aes = list(size = 2)))

plt_diag_icd
plt_diag_icd[["data"]] %>%
  summarise(freq = mean(freq), .by = c(criteria, diagnosis)) %>%
  arrange(criteria, desc(freq))
```

# Cumulative top frequency plots


```{r, fig.width=3, fig.height=3.5}
# Per-model cumulative-frequency plots (original responses). Titles fixed:
# the Opus plot was previously labelled "Claude3 Haiku".
sub <- "Original responses"
cumulative_frequency_plot(df_gpt3.5)$plot+ggtitle("GPT3", sub)
cumulative_frequency_plot(df_gpt4.0)$plot+ggtitle("GPT4", sub)
cumulative_frequency_plot(df_claude3_haiku_t1.0)$plot+ggtitle("Claude3 Haiku", sub)
cumulative_frequency_plot(df_claude3_opus_t1.0)$plot+ggtitle("Claude3 Opus", sub)
cumulative_frequency_plot(df_gemini1.0_pro_t1.0)$plot+ggtitle("Gemini Pro 1.0", sub)
cumulative_frequency_plot(df_gemini1.5_pro_t1.0)$plot+ggtitle("Gemini Pro 1.5", sub)
```

```{r, fig.width=3, fig.height=3.5}
# Per-model cumulative-frequency plots (ICD-converted responses). Titles
# fixed: Opus was labelled "Claude3 Haiku ICD" and Gemini 1.5 was labelled
# "Gemini Pro 1.0 ICD".
sub <- "ICD converted responses"
cumulative_frequency_plot(df_gpt3.5_icd)$plot+ggtitle("GPT3 ICD", sub)
cumulative_frequency_plot(df_gpt4.0_icd)$plot+ggtitle("GPT4 ICD", sub)
cumulative_frequency_plot(df_claude3_haiku_t1.0_icd)$plot+ggtitle("Claude3 Haiku ICD", sub)
cumulative_frequency_plot(df_claude3_opus_t1.0_icd)$plot+ggtitle("Claude3 Opus ICD", sub)
cumulative_frequency_plot(df_gemini1.0_pro_t1.0_icd)$plot+ggtitle("Gemini Pro 1.0 ICD", sub)
cumulative_frequency_plot(df_gemini1.5_pro_t1.0_icd)$plot+ggtitle("Gemini Pro 1.5 ICD", sub)
```


```{r, fig.width=4, fig.height=3.5}
# Combined cumulative-frequency plot of the top 25 diagnoses (original
# responses), followed by the mean total frequency per criteria.
plt_freq <- multi_cumulative_frequency_plot(
  n_diagnoses = 25, distribution_vis = "points",
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0
) + ggtitle("Original responses")

plt_freq
summarise(plt_freq$data, freq = mean(total_frequency), .by = criteria)
```

```{r, fig.width=4, fig.height=3.5}
# Combined cumulative-frequency plot of the top 25 diagnoses (ICD-converted
# responses), followed by the mean total frequency per criteria.
plt_freq_icd <- multi_cumulative_frequency_plot(
  n_diagnoses = 25, distribution_vis = "points",
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) + ggtitle("ICD converted responses")

plt_freq_icd
summarise(plt_freq_icd$data, freq = mean(total_frequency), .by = criteria)
```

# Diagnosis rank table

```{r}
# Ranks of mast-cell / mastocytosis / anaphylaxis diagnoses per model
# (original responses); diagnosis text truncated to 60 characters for display.
diagnosis_rank_table(df_gpt3.5, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_claude3_haiku_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_claude3_opus_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_gemini1.0_pro_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_gemini1.5_pro_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
```
```{r}
# Same search on the ICD-converted responses.
diagnosis_rank_table(df_gpt3.5_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_gpt4.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_claude3_haiku_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_claude3_opus_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_gemini1.0_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
diagnosis_rank_table(df_gemini1.5_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60)) 
```


```{r}
# Combined rank table for selected ICD codes (T78.2, D47.02, D89.41, D89.49,
# D89.4) across the six ICD-converted model outputs.
rank_table <- multi_diagnosis_rank_table(
  search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
)
rank_table
```

```{r}
# Render the rank table as a flextable with centred value columns.
flextable(rank_table) %>%
  width(width = 30) %>%
  align(j = 2:3, align = "center", part = "all")
```

# Diversity

```{r, fig.width=4, fig.height=3.5}
# Shannon-diversity comparison across models (original responses).
multi_shannon_plot(
  distribution_vis = "points", wrap_width = 45, n_diag = 25,
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0
)
```


```{r, fig.width=4, fig.height=3.5}
# Shannon-diversity comparison (ICD-converted responses), the mean Shannon
# value per criteria, and the pairwise p-values extracted from the plot.
plt_div_icd <- multi_shannon_plot(
  distribution_vis = "points", wrap_width = 45, n_diag = 25,
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
)

plt_div_icd
summarise(plt_div_icd$data, shannon = mean(shannon), .by = criteria)
extract_ggpubr_pvalues(plt_div_icd)
```

# Similarity

```{r, fig.width=4.25, fig.height=3.5}
# Per-model Bray-Curtis similarity heatmaps (original responses).
diagnosis_similarity_heatmap(df_gpt3.5, method = "bray")
diagnosis_similarity_heatmap(df_gpt4.0, method = "bray")
diagnosis_similarity_heatmap(df_claude3_haiku_t1.0, method = "bray")
diagnosis_similarity_heatmap(df_claude3_opus_t1.0, method = "bray")
diagnosis_similarity_heatmap(df_gemini1.0_pro_t1.0, method = "bray")
diagnosis_similarity_heatmap(df_gemini1.5_pro_t1.0, method = "bray")
```
```{r, fig.width=4.25, fig.height=3.5}
# Per-model Bray-Curtis similarity heatmaps (ICD-converted responses).
diagnosis_similarity_heatmap(df_gpt3.5_icd, method = "bray")
diagnosis_similarity_heatmap(df_gpt4.0_icd, method = "bray")
diagnosis_similarity_heatmap(df_claude3_haiku_t1.0_icd, method = "bray")
diagnosis_similarity_heatmap(df_claude3_opus_t1.0_icd, method = "bray")
diagnosis_similarity_heatmap(df_gemini1.0_pro_t1.0_icd, method = "bray")
diagnosis_similarity_heatmap(df_gemini1.5_pro_t1.0_icd, method = "bray")
```

```{r, fig.width=4.25, fig.height=3.5}
# Combined Bray-Curtis similarity heatmap, original responses, no dendrogram.
# Uses FALSE rather than the reassignable shorthand F.
multi_diagnosis_similarity_heatmap(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0,
  method = "bray",
  show_dend = FALSE,
  label_size = 6,
  title_size = 9
)
```

```{r, fig.width=4.25, fig.height=3.5}
# Combined Bray-Curtis similarity heatmap, ICD-converted responses, no
# dendrogram. Uses FALSE rather than the reassignable shorthand F.
multi_diagnosis_similarity_heatmap(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd,
  method = "bray",
  show_dend = FALSE,
  label_size = 6,
  title_size = 9
)
```
- Bray-Curtis similarity measures how similar two diagnostic criteria are in terms of the set of alternative diagnoses they produce and the frequencies of those diagnoses.
- This demonstrates that the SLE criteria result in a very similar set and frequency of diagnoses, while the diagnoses associated with the two MCAS criteria are as different from each other as they are from those generated by the criteria of other conditions.

#### Multi Bray Curtis Similarity

```{r}
# Bray-Curtis similarity (1 - dissimilarity) between every
# "<criteria>_<model>" combination, based on how often each diagnosis occurs.
#
# df: combined diagnoses with columns `original_df` (source model),
#     `criteria` and `diagnosis`.
# Returns a square numeric matrix with row/column names "<criteria>_<model>".
calculate_similarity <- function(df){
  counts <- df %>% 
    rename(model = original_df) %>% 
    count(model, criteria, diagnosis) %>% 
    unite(criteria, criteria, model)

  # Build the abundance matrix from the counts themselves. Calling table() on
  # the already-counted rows (one row per criteria/diagnosis) produced only
  # 0/1 presence values, which discarded the diagnosis frequencies.
  xtabs(n ~ criteria + diagnosis, data = counts) %>% 
    vegan::vegdist(method = "bray") %>% 
    as.matrix() %>% 
    {1 - .}
}
```

```{r, fig.width=6.5, fig.height=4.5}
# Criteria-by-model Bray-Curtis similarity heatmap (ICD-converted responses).
combine_data_frames(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) %>%
  calculate_similarity() %>%
  model_criteria_heatmap(
    color_scale = viridis::viridis(3),
    title = " ",
    metric = "Bray-Curtis\nsimilarity",
    symmetric = FALSE,
    font_size = 8
  )
```

```{r, fig.width=6.5, fig.height=4.5}
# Criteria-by-model Bray-Curtis similarity heatmap (original responses).
combine_data_frames(
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0
) %>%
  calculate_similarity() %>%
  model_criteria_heatmap(
    color_scale = viridis::viridis(3),
    title = " ",
    metric = "Bray-Curtis\nsimilarity",
    symmetric = FALSE,
    font_size = 8
  )
```
### PCA

```{r, fig.width=4.25, fig.height=3.5}
# Per-model PCA plots (original responses). The two Gemini plots previously
# shared the ambiguous title "Gemini"; they are now distinguished by version.
diagnosis_pca_plot(df_gpt3.5) + ggtitle("GPT3")
diagnosis_pca_plot(df_gpt4.0) + ggtitle("GPT4")
diagnosis_pca_plot(df_claude3_haiku_t1.0) + ggtitle("Claude Haiku")
diagnosis_pca_plot(df_claude3_opus_t1.0) + ggtitle("Claude Opus")
diagnosis_pca_plot(df_gemini1.0_pro_t1.0) + ggtitle("Gemini 1.0 Pro")
diagnosis_pca_plot(df_gemini1.5_pro_t1.0) + ggtitle("Gemini 1.5 Pro")
```


```{r}
# PCA on a (model, criteria) x diagnosis count matrix built from the six
# ICD-converted data frames, then PC1 vs PC2 coloured by criteria.
df <- listN(df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd) %>% 
  # Tag each data frame with its list name as the `model` column.
  Map(function(x, y) mutate(x, model = y), ., names(.)) %>% 
  bind_rows() %>% 
  count(model, criteria, diagnosis) %>% 
  # One row per model/criteria, one column per diagnosis (0 when absent).
  pivot_wider(names_from = "diagnosis", values_from = "n", values_fill = 0) %>% 
  unite(id, model, criteria, sep = "__") %>% 
  column_to_rownames("id") %>% 
  prcomp(scale. = FALSE)

as.data.frame(df$x) %>% 
  rownames_to_column("id") %>% 
  separate(id, into = c("model", "criteria"), sep = "__") %>% 
  format_criteria() %>% 
  format_models() %>% 
  ggplot(aes(x = PC1, y = PC2, color = criteria)) +
  geom_point() +
  # ggrepel::geom_label_repel() +
  theme_bw() +
  scale_color_brewer(palette = "Dark2")
```

# Precision

- Precision represents how similar each iteration of a 10-point differential diagnosis is to all other differential diagnoses generated from the same set of criteria.
- I.e., how reproducible the 10-point differential diagnosis is for each set of criteria.
- Measured by computing the Bray-Curtis similarity between all iterations within a set of criteria.

```{r, eval=F}
# Script for calculating all Bray-Curtis similarity values within a criteria
# Found in source(here("scripts/diversity_analysis/calculate_precision.R"))
# Calculate precision
library(here)
source(here("utils/data_processing.R"))

# Model identifiers are the middle part of "diagnoses_<model>[_icd].csv.gz";
# vapply() keeps the result a character vector even when no files match.
models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>% 
  str_split("diagnoses_|_icd|\\.csv") %>% 
  vapply(function(parts) parts[2], character(1)) %>% 
  unique()

use_icd <- TRUE

if (use_icd) {
  models <- str_glue("{models}_icd")
}

# For each model: read its diagnoses, compute precision, write the result.
for (m in models) {
  print(sprintf("READING IN DATA FOR: %s", m))
  read_path <- sprintf("data/processed_diagnoses/diagnoses_%s.csv.gz", m)
  df <- read_csv(here(read_path))
  
  print(sprintf("CALCULATING PRECISION FOR: %s", m))
  df <- calculate_precision(df)
  
  print(sprintf("WRITING PRECISION DATA FOR: %s", m))
  out_path <- sprintf("data/diversity_analysis/diagnosis_precision_%s.csv.gz", m)
  write_csv(df, here(out_path))
}
```

```{r, fig.width=4, fig.height=3.5}
# Convert Bray-Curtis distance summaries into similarity summaries
# (similarity = 1 - distance).
#
# df: data frame with numeric columns `mean`, `min` and `max` (distances).
# Returns df with those columns replaced by similarities. Because 1 - x
# reverses order, the new `min` comes from the old `max` and vice versa.
#
# Both old columns are captured before anything is overwritten. The previous
# dplyr::mutate() version evaluated `min = 1 - max` after `max` had already
# been replaced, so `min` silently came back unchanged.
precision_dist_to_sim <- function(df){
  dist_min <- df$min
  dist_max <- df$max
  df$mean <- 1 - df$mean
  df$max <- 1 - dist_min
  df$min <- 1 - dist_max
  df
}

# Precision (mean within-criteria Bray-Curtis similarity) for the ICD-converted
# responses, excluding Gemini 1.5 Flash, with BH-adjusted pairwise Wilcoxon
# tests; then per-criteria means and the extracted p-values.
plt_precision_icd <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  list(
    theme_bw(),
    theme(axis.text.x = element_text(angle = 45, hjust = 1)),
    labs(x = "", y = "Average Bray-Curtis Similarity"),
    ggpubr::geom_pwc(
      aes(group = criteria),
      method = "wilcox.test", p.adjust.method = "BH", hide.ns = TRUE,
      label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6,
      step.increase = 0.14, tip.length = 0.02
    ),
    labs(color = NULL),
    scale_color_brewer(palette = "Dark2"),
    plot_selector("points")
  )

plt_precision_icd
summarise(plt_precision_icd$data, mean = mean(mean), .by = criteria)
extract_ggpubr_pvalues(plt_precision_icd)
```


# iNEXT

```{r, fig.width=12, fig.height=4}
# Print the three iNEXT::ggiNEXT() plot types (type = 1, 2, 3) for one iNEXT
# result, faceted and coloured by assemblage.
#
# inext_obj: an object accepted by iNEXT::ggiNEXT (here read from RDS files).
# Returns NULL invisibly; called for its printing side effect.
inext_plots <- function(inext_obj){
  for (plot_type in 1:3) {
    plt <- iNEXT::ggiNEXT(inext_obj, type = plot_type, facet.var = "Assemblage", color.var = "Assemblage") +
      theme_classic() +
      theme(axis.text.x = element_text(angle = 90)) +
      # A single colour scale: the previous version also added a "Set1"
      # scale that this one immediately replaced.
      scale_color_brewer(palette = "Dark2")
    print(plt)
  }
  invisible(NULL)
}

readRDS(here("data/diversity_analysis/mcas_iNEXT_gpt4_e250000.RDS")) %>% inext_plots()
readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_gpt4_e200000.RDS")) %>% inext_plots()
readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_psuedoMinus_gpt4_e200000.RDS")) %>% inext_plots()
```



```{r, fig.width=7.4, fig.height=6.5}
# custom_labeler <- function(x, wrap_width=33) {
#     x %>%
#         str_replace("___.+$", "") %>%
#         str_wrap(width = wrap_width)
# }
```

# Final plot

### Version 1

```{r, fig.width=7.5, fig.height=8.5, message = F}
# Figure-wide settings shared by the panels below (Version 1).
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Text and legend styling added to every panel.
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, "cm"),
  # element_rect(size = ) is deprecated since ggplot2 3.4.0; use linewidth.
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, "cm")  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, "cm"),
  # legend.box.spacing = unit(0, "cm")
))

# Tight, uniform padding around facet strip labels.
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Top-diagnosis panel (Version 1). Theme layers are order-dependent: later
# theme() calls override earlier ones.
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  # Remove the gap between facet panels.
  theme(panel.spacing = unit(0, "lines")) +
  guides(color = guide_legend(override.aes = list(size = 2)))  # Increase the point size in the legend
  

# Rank-abundance panel (Version 1) with a two-column legend below the plot.
plt_rank <- multi_ranked_abundance_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) + list(
  theme(legend.position = "bottom", legend.direction = "horizontal"),
  apply_text_formatting,
  guides(color = guide_legend(ncol = 2))
)

# Cumulative-frequency panel (Version 1). The y label is derived from
# n_diagnoses_cumulative so it always matches the plotted value.
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(
    y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"),
    x = NULL
  )

# Shannon-diversity panel (Version 1) with a two-column legend below the plot.
plt_shannon <- multi_shannon_plot(
  distribution_vis = "points", wrap_width = 45, n_diag = 25,
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) + list(
  apply_text_formatting,
  theme(legend.position = "bottom", legend.direction = "horizontal"),
  guides(color = guide_legend(ncol = 2))
)

# Precision panel (Version 1): mean within-criteria Bray-Curtis similarity,
# Gemini 1.5 Flash excluded, BH-adjusted pairwise Wilcoxon tests.
plt_precision <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  list(
    theme_bw(),
    theme(axis.text.x = element_text(angle = 45, hjust = 1)),
    labs(x = "", y = "Average Bray-Curtis\nSimilarity"),
    ggpubr::geom_pwc(
      aes(group = criteria),
      method = "wilcox.test", p.adjust.method = "BH", hide.ns = TRUE,
      label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6,
      step.increase = 0.14, tip.length = 0.02
    ),
    labs(color = NULL),
    scale_color_brewer(palette = "Dark2"),
    plot_selector("points"),
    apply_text_formatting,
    theme(legend.position = "bottom", legend.direction = "horizontal"),
    guides(color = guide_legend(ncol = 2))
  )

# Version 1 layout: top-diagnosis panel (A), a thin spacer row, then four
# horizontally aligned panels labelled B-E. rel_widths/rel_heights are
# positional and must match the children listed.
full_plt <- plot_grid(
  
  ###
  plt_diags,
  ###
  NULL,  # spacer between the top and bottom rows
  plot_grid(
      plt_rank,
      plt_cumulative,
      plt_shannon,
      plt_precision,
      nrow = 1, 
      axis = 'tb',
      align = 'h',
      rel_widths = c(1, 0.7, 0.7, 0.7),
      labels = c(LETTERS[2:5]),
      vjust = 0.2
    ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A","","")
)  

full_plt
```

### Version 2
```{r, fig.width=7.5, fig.height=8.5, message = F}
# Figure-wide settings shared by the panels below (Version 2).
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Text and legend styling added to every panel.
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, "cm"),
  # element_rect(size = ) is deprecated since ggplot2 3.4.0; use linewidth.
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, "cm")  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, "cm"),
  # legend.box.spacing = unit(0, "cm")
))

# Tight, uniform padding around facet strip labels.
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Top-diagnosis panel (Version 2) with a single-row legend. Theme layers are
# order-dependent: later theme() calls override earlier ones.
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  # Remove the gap between facet panels.
  theme(panel.spacing = unit(0, "lines")) +
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))  # Increase the point size in the legend; single-row legend
  

# Rank-abundance panel (Version 2) with a single-column legend inside the plot.
plt_rank <-
  multi_ranked_abundance_plot(
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  # A numeric `legend.position` is deprecated since ggplot2 3.5.0; request an
  # inside legend and give its coordinates via `legend.position.inside`.
  theme(legend.position = "inside", legend.position.inside = c(0.7, 0.7)) +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 1)) +
  labs(color = NULL)

# Cumulative-frequency panel (Version 2). The y label is derived from
# n_diagnoses_cumulative so it always matches the plotted value.
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(
    y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"),
    x = NULL
  )

# Shannon-diversity panel (Version 2) with a two-column legend below the plot.
plt_shannon <- multi_shannon_plot(
  distribution_vis = "points", wrap_width = 45, n_diag = 25,
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) + list(
  apply_text_formatting,
  theme(legend.position = "bottom", legend.direction = "horizontal"),
  guides(color = guide_legend(ncol = 2))
)

# Precision panel (Version 2): mean within-criteria Bray-Curtis similarity,
# Gemini 1.5 Flash excluded, BH-adjusted pairwise Wilcoxon tests.
plt_precision <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  list(
    theme_bw(),
    theme(axis.text.x = element_text(angle = 45, hjust = 1)),
    labs(x = "", y = "Average Bray-Curtis\nSimilarity"),
    ggpubr::geom_pwc(
      aes(group = criteria),
      method = "wilcox.test", p.adjust.method = "BH", hide.ns = TRUE,
      label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6,
      step.increase = 0.14, tip.length = 0.02
    ),
    labs(color = NULL),
    scale_color_brewer(palette = "Dark2"),
    plot_selector("points"),
    apply_text_formatting,
    theme(legend.position = "bottom", legend.direction = "horizontal"),
    guides(color = guide_legend(ncol = 2))
  )

# Version 2 layout: top-diagnosis panel (A), a thin spacer row, then the
# rank-abundance panel next to Shannon + precision sharing one legend.
full_plt <- plot_grid(
  
  ###
  plt_diags,
  ###
  NULL,  # spacer between the top and bottom rows
  plot_grid(
      plt_rank,
      plot_grid(
        # Shannon and precision side by side without their own legends...
        plot_grid(
          plt_shannon + theme(legend.position = "none"),
          plt_precision + theme(legend.position = "none"),
          nrow = 1,
          axis = 'tb',
          align = 'h'
        ),
        # ...and one shared single-row legend underneath. guide_legend() has
        # no `row` argument (it was silently ignored); `nrow` is the intended one.
        get_legend(plt_shannon + guides(color = guide_legend(nrow = 1))),
        ncol = 1,
        rel_heights = c(1, 0.1)
      ),
      nrow = 1,
      rel_widths = c(1, 1),
      # labels = c(LETTERS[2:5]),
      vjust = 0.2
    ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)

full_plt
```

### Version 3
```{r, fig.width=7.5, fig.height=8.5, message = F}
# Figure-wide settings shared by the panels below (Version 3).
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Text and legend styling added to every panel.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    # Only a 1pt right margin after each legend label.
    legend.text = element_text(size = label_size, margin = margin(t = 0, r = 1, b = 0, l = 0)),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    # element_rect(size = ) is deprecated since ggplot2 3.4.0; use linewidth.
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    # Slightly negative spacing packs legend rows closer together.
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(5, "pt")
  )
)

# Tight, uniform padding around facet strip labels.
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Top-diagnosis panel (Version 3) with a single-row legend. Theme layers are
# order-dependent: later theme() calls override earlier ones.
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  # Remove the gap between facet panels.
  theme(panel.spacing = unit(0, "lines")) +
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))  # Increase the point size in the legend; single-row legend
  

# Ranked-abundance panel for the six models' `_icd` data frames, with an
# untitled single-column legend placed inside the plotting area.
plt_rank <-
  multi_ranked_abundance_plot(
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  # ggplot2 >= 3.5.0 deprecates a numeric `legend.position`; inside placement
  # is requested with "inside" plus `legend.position.inside` coordinates.
  theme(legend.position = "inside", legend.position.inside = c(0.7, 0.7)) +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 1)) +
  labs(color = NULL)

# Cumulative-frequency panel: combined frequency of the top
# `n_diagnoses_cumulative` diagnoses per model. The y-axis label is derived
# from that setting so the two cannot drift apart.
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(
    y = paste0(
      "Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"
    ),
    x = NULL
  )

# Shannon-diversity panel for the six models' `_icd` data frames.
# The x-axis labels are rotated 90 degrees, and the two-column legend is kept
# on the plot so it can be extracted when the figure is assembled.
plt_shannon <-
  multi_shannon_plot(
    distribution_vis = "points",
    wrap_width = 45,
    n_diag = 25,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  apply_text_formatting +
  theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
  guides(color = guide_legend(ncol = 2))

# Precision panel: mean Bray-Curtis similarity per `criteria` level, with
# pairwise Wilcoxon tests (BH-adjusted p-values; non-significant pairs hidden).
plt_precision <-
  read_csv(
    here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")
  ) %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  # Keep only the six models compared in the other panels of this figure.
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
  labs(x = "", y = "Mean Bray-Curtis Similarity", color = NULL) +
  ggpubr::geom_pwc(
    aes(group = criteria),
    method = "wilcox.test",
    p.adjust.method = "BH",
    hide.ns = TRUE,
    label = "p.adj.signif",
    bracket.nudge.y = 0.3,
    vjust = 0.6,
    step.increase = 0.14,
    tip.length = 0.02
  ) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points") +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Between-model Bray-Curtis similarity heatmap with dendrogram; drawn with
# ComplexHeatmap::draw() when the full figure is assembled.
plt_similarity <- multi_diagnosis_similarity_heatmap(
  method = "bray",
  show_dend = TRUE,
  dendrogram_weight = unit(2.5, "mm"),
  legend_label = "Bray-Curtis similarity",
  legend_direction = "horizontal",
  # Reuse the figure-wide text sizes instead of repeating 6 and 9 here.
  label_size = label_size,
  title_size = title_size,
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
)

# Assemble the figure. The top-diagnosis panel sits above a bottom row that
# holds, left to right:
#   - the similarity heatmap,
#   - the rank-abundance plot,
#   - the Shannon and precision panels, which share one legend.
# Panel letters A-E are then drawn at fixed positions. Intermediate pieces are
# kept local so only `full_plt` is created.
full_plt <- local({
  heatmap_grob <- grid::grid.grabExpr(
    ComplexHeatmap::draw(plt_similarity, heatmap_legend_side = "bottom")
  )
  heatmap_col <- plot_grid(
    heatmap_grob, NULL,
    ncol = 1, rel_heights = c(1, 0.1)
  )
  rank_col <- plot_grid(
    plt_rank, NULL,
    ncol = 1, rel_heights = c(1, 0.1)
  )

  # Shannon and precision side by side, legends stripped, axes aligned.
  diversity_pair <- plot_grid(
    plt_shannon + theme(legend.position = "none"),
    plt_precision + theme(legend.position = "none"),
    nrow = 1,
    axis = "tb",
    align = "h"
  )
  # The Shannon legend, shifted right, stands in for both panels.
  shared_legend <- plot_grid(
    NULL, get_legend(plt_shannon),
    nrow = 1, rel_widths = c(0.2, 1)
  )
  diversity_col <- plot_grid(
    diversity_pair, shared_legend, NULL,
    ncol = 1, rel_heights = c(1, 0.05, 0.1)
  )

  bottom_row <- plot_grid(
    NULL, heatmap_col, rank_col, diversity_col,
    nrow = 1,
    rel_widths = c(0.1, 0.8, 1, 0.9),
    vjust = 0.2
  )

  stacked <- plot_grid(
    plt_diags, NULL, bottom_row,
    ncol = 1, rel_heights = c(1.2, 0.07, 0.65)
  )

  cowplot::ggdraw(stacked) +
    cowplot::draw_plot_label(
      c("A", "B", "C", "D", "E"),
      x = c(0, 0, 0.35, 0.67, 0.83),
      y = c(1, 0.38, 0.38, 0.38, 0.38)
    )
})
full_plt
```

- Add a two-column legend under panels D and E.

```{r, eval=F}
# Export the assembled figure at the same 7.5 x 8.5 size as the figure chunk.
ggsave(
  filename = here("figures/3_diagnosis_diversity.pdf"),
  plot = full_plt,
  width = 7.5,
  height = 8.5
)
```

Alternative PDF table setting (not applied in the chunk below): `set_table_properties(opts_pdf = list(tabcolsep = 0))`.
```{r}
# Let flextable ignore font settings in PDF/LaTeX output (pdflatex-friendly).
set_flextable_defaults(fonts_ignore = TRUE)

# Rank table for the diagnoses matching the given ICD codes across the six
# models. Each regex alternative ends in a space, so `D89\\.4 ` does not also
# match D89.41 / D89.49. The table is printed with a PDF preview for its side
# effect, then returned visibly so the chunk also renders it.
local({
  rank_tbl <- multi_diagnosis_rank_table(
    search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  )
  ft <- flextable(rank_tbl)
  ft <- width(ft, width = 2)
  ft <- fontsize(ft, size = 9)
  ft <- fontsize(ft, size = 10, part = "header")
  ft <- padding(ft, padding = 0)
  ft <- align(ft, j = 2:3, align = "center", part = "all")
  ft <- set_table_properties(ft, opts_pdf = list(arraystretch = 1.25))
  print(ft, preview = "pdf")
  ft
})
```


